package com.lw.leetcode.arr.b;

import java.util.Arrays;

/**
 * LeetCode 2333. Minimum Sum of Squared Difference.
 *
 * <p>Each element of {@code nums1} may be changed by +1/-1, at most {@code k1} times in total, and
 * each element of {@code nums2} at most {@code k2} times in total. Only the absolute difference
 * per index matters, so both budgets are merged and spent greedily on the largest differences.
 *
 * @author liw
 * @version 1.0
 * @date 2022/7/11 11:03
 */
public class MinSumSquareDiff {

    public static void main(String[] args) {
        MinSumSquareDiff test = new MinSumSquareDiff();

        int[][] nums1s = {
                {1, 2, 3, 4},
                {1, 2, 3, 4},
                {1, 4, 10, 12},
                {10, 10, 10, 11, 5},
                {11, 12, 13, 14, 15},
        };
        int[][] nums2s = {
                {2, 10, 20, 19},
                {2, 10, 20, 19},
                {5, 8, 6, 9},
                {1, 0, 6, 6, 1},
                {13, 16, 16, 12, 14},
        };
        // {k1, k2} per case
        int[][] ks = {{0, 0}, {2, 0}, {10, 5}, {11, 27}, {3, 6}};
        long[] expected = {579, 515, 0, 0, 3};

        for (int i = 0; i < expected.length; i++) {
            long actual = test.minSumSquareDiff(nums1s[i], nums2s[i], ks[i][0], ks[i][1]);
            System.out.println(actual + " (expected " + expected[i] + ")");
        }
    }

    /**
     * Returns the minimum possible sum of {@code (nums1[i] - nums2[i])^2} after spending at most
     * {@code k1} unit changes on {@code nums1} and at most {@code k2} unit changes on {@code nums2}.
     *
     * <p>The input arrays are not modified.
     *
     * @param nums1 first array; expected to have the same length as {@code nums2}
     * @param nums2 second array
     * @param k1    number of +/-1 modifications allowed on {@code nums1}
     * @param k2    number of +/-1 modifications allowed on {@code nums2}
     * @return the minimum sum of squared differences (0 for empty input)
     */
    public long minSumSquareDiff(int[] nums1, int[] nums2, int k1, int k2) {
        int length = nums1.length;
        if (length == 0) {
            return 0;
        }

        // Work on a copy so the caller's arrays are left untouched.
        int[] diff = new int[length];
        for (int i = 0; i < length; i++) {
            diff[i] = Math.abs(nums1[i] - nums2[i]);
        }
        Arrays.sort(diff);

        // A change to either array shrinks |nums1[i] - nums2[i]| by one, so the two budgets are
        // interchangeable. Held in a long so k1 + k2 can never overflow.
        long ops = (long) k1 + k2;

        // Invariant: diff[st..length-1] form the "top group", all conceptually lowered to `top`.
        int st = length - 1;
        long top = diff[st];
        while (st > 0) {
            int next = diff[st - 1];
            // Cost to lower the whole top group down to the next smaller value. Computed in long:
            // (top - next) * groupSize can reach ~1e10, which would overflow int.
            long cost = (top - next) * (long) (length - st);
            if (ops <= cost) {
                // Not enough (or exactly enough) budget to absorb the next element; the remaining
                // ops are spread over the current group below.
                break;
            }
            ops -= cost;
            top = next;
            st--;
        }

        long sum = 0;
        for (int i = 0; i < st; i++) {
            sum += (long) diff[i] * diff[i];
        }

        // Spread the remaining budget evenly over the top group: every member drops by
        // ops / size, and `rem` of them drop by one more.
        int size = length - st;
        long level = top - ops / size;
        if (level <= 0) {
            // Enough budget to bring every member of the group to zero.
            return sum;
        }
        long rem = ops % size;
        sum += level * level * (size - rem);
        sum += (level - 1) * (level - 1) * rem;
        return sum;
    }

}
